
Conversation

@LucasWilkinson
Collaborator

When doing prefill, up-convert the kv-cache from fp8 to bf16 and call the bf16 prefill kernel instead of the decode kernel. This PR introduces global workspace management so that the bf16 workspace overlaps with the MoE workspace buffers.
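
For illustration, a minimal standalone sketch of the up-conversion step (the shapes, tensor names, and the plain .to() cast here are assumptions for the sketch, not the PR's actual kernel path):

import torch

# Hypothetical cache shapes, for the sketch only.
num_blocks, block_size, head_dim = 4, 16, 576

# fp8 kv-cache as stored for decode (requires a PyTorch build with float8 support).
kv_cache_fp8 = torch.randn(num_blocks, block_size, head_dim).to(torch.float8_e4m3fn)

# bf16 workspace; in the PR this would come from the shared workspace manager
# (overlapping with the MoE buffers), here it is just a fresh allocation.
kv_workspace_bf16 = torch.empty(num_blocks, block_size, head_dim, dtype=torch.bfloat16)

# Up-convert once; the bf16 prefill kernel would then read kv_workspace_bf16.
kv_workspace_bf16.copy_(kv_cache_fp8.to(torch.bfloat16))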

@mergify bot added the deepseek (Related to DeepSeek models) and v1 labels on Oct 26, 2025
@LucasWilkinson marked this pull request as ready for review on October 26, 2025 21:55

@chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 125 to 136
def get(self, spec: "WorkspaceSpec") -> torch.Tensor:
    """Get a workspace tensor for the given spec.

    Args:
        spec: The workspace specification.

    Returns:
        A tensor view into the workspace buffer with the requested shape and dtype.
    """
    num_bytes = spec.num_bytes()
    current_workspace = self._ensure_workspace_size(num_bytes, spec.name)
    return current_workspace[:num_bytes].view(spec.dtype).reshape(spec.shape)


P1: Allocating workspaces fails due to invalid view call

WorkspaceManager.get reinterprets the byte buffer with current_workspace[:num_bytes].view(spec.dtype) but Tensor.view only accepts a shape, not a dtype. Passing a torch.dtype raises TypeError: 'torch.dtype' object cannot be interpreted as an integer, so every call to reserve/get will crash before returning a workspace. The manager needs to reshape using a size tuple and cast with view(dtype) via reinterpret_cast semantics (e.g. view(-1).view(spec.dtype) or view(dtype).reshape).


Comment on lines 794 to 805
# Process decode tokens
if num_decode_tokens > 0:
    attn_out = self._forward_fp8_kv(
        q[:num_decode_tokens],
        kv_cache,
        topk_indices_global[:num_decode_tokens],
        attn_metadata,
    )

if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(


P0: Prefill-only batches reference attn_out before initialization

In the fp8 path of FlashMLASparseImpl.forward, attn_out is only assigned inside the if num_decode_tokens > 0 branch. The subsequent if num_prefill_tokens > 0 branch unconditionally reads decode_attn_out = attn_out, which raises UnboundLocalError whenever a batch contains only prefill tokens. Prefill batches are common during initial context ingestion, so this path will always fail until attn_out is initialized for the prefill case.
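
One possible shape of the fix, as a small self-contained sketch (the variable names mirror the snippet above, but the tensors are stand-ins rather than the PR's code):

import torch

# Stand-in sizes for the sketch: a prefill-only batch.
num_decode_tokens, num_prefill_tokens = 0, 3
num_actual_toks, num_heads, kv_lora_rank = 3, 2, 8
q = torch.randn(num_actual_toks, num_heads, kv_lora_rank)

# Allocate attn_out up front so neither branch reads an unbound local.
attn_out = q.new_empty((num_actual_toks, num_heads, kv_lora_rank))
if num_decode_tokens > 0:
    attn_out[:num_decode_tokens] = q[:num_decode_tokens]  # stand-in for the fp8 decode path
if num_prefill_tokens > 0:
    attn_out[num_decode_tokens:] = q[num_decode_tokens:]  # stand-in for the bf16 prefill path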


Comment on lines 803 to 811
if num_prefill_tokens > 0:
    decode_attn_out = attn_out
    attn_out = q.new_empty(
        (num_actual_toks, self.num_heads, self.kv_lora_rank),
        dtype=q.dtype,
        device=q.device,
    )
    attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]


P1: Decode outputs stored into prefill slots

When both decode and prefill tokens exist, the fp8 path copies decode attention results with attn_out[:num_prefill_tokens] = decode_attn_out[:num_prefill_tokens]. Decode tokens occupy the first num_decode_tokens entries, so this writes them into the wrong slice and fails whenever num_prefill_tokens > num_decode_tokens because the right-hand side is shorter than the target. The assignment should use num_decode_tokens to preserve decode outputs and avoid size mismatches.
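
A tiny standalone check of the suggested slicing (illustrative tensors only, not the PR's code):

import torch

num_decode_tokens, num_prefill_tokens = 2, 5
num_actual_toks = num_decode_tokens + num_prefill_tokens

decode_attn_out = torch.randn(num_decode_tokens, 4)  # decode results occupy the first rows
attn_out = torch.zeros(num_actual_toks, 4)

# Preserve decode outputs in the first num_decode_tokens slots; slicing by
# num_prefill_tokens would both misplace them and mismatch sizes here.
attn_out[:num_decode_tokens] = decode_attn_out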


@yewentao256 added the ready (ONLY add when PR is ready to merge/full CI is needed) label on Oct 31, 2025
@LucasWilkinson force-pushed the lwilkinson/upconvert-all-2 branch 2 times, most recently from 7a3b6b6 to 39ba79c on November 4, 2025 05:35
None, # Pass None to avoid using sampled token counts
)

current_workspace_manager().get_simultaneous(
Collaborator


Do we have to do these allocations during model execution? Is it possible to set up the memory buffer before real execution to reduce the runtime overhead?
For example, I'm thinking of (roughly as sketched below):

  1. during model init, call current_workspace_manager().get_simultaneous() to tell the workspace manager the max possible size the model may use
  2. lock the memory space
  3. during the profile run, allocate the memory and save it to self.xxx, e.g. self.workspace13 = workspace13; self.workspace2 = workspace2
  4. during model execution, just use self.workspace13
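
For illustration, a self-contained sketch of that pattern with a stub manager (the real get_simultaneous signature in this PR may differ, so treat the API below as an assumption):

import torch

class _StubWorkspaceManager:
    """Stand-in for the PR's workspace manager, for illustration only."""

    def get_simultaneous(self, *sizes_in_bytes: int):
        # Carve non-overlapping uint8 views out of one backing buffer.
        buf = torch.empty(sum(sizes_in_bytes), dtype=torch.uint8)
        views, offset = [], 0
        for n in sizes_in_bytes:
            views.append(buf[offset:offset + n])
            offset += n
        return views

class Model:
    def profile_run(self):
        # Steps 1-3: reserve the worst case once and cache the views.
        mgr = _StubWorkspaceManager()
        self.workspace13, self.workspace2 = mgr.get_simultaneous(1 << 20, 1 << 18)

    def forward(self):
        # Step 4: reuse the cached views; no per-step allocation.
        return self.workspace13, self.workspace2

m = Model()
m.profile_run()
m.forward()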

Collaborator Author


Yeah, we should do something like this. The only complication currently is when self.fused_experts.supports_chunking() == False (i.e. PPLX or DeepEP LL): then we need the profile run to know the shapes of the workspaces, because in that case the profile run actually mirrors the worst-case scenario (hence gating this logic).

This code is meant to be a straight refactor of #27426 (once that lands) to use the new workspace manager, so I'm partial to leaving this optimization to a future PR if you are cool with that. I am trying to learn more about the MoE chunking code in order to propose a broader UX refactor there, since I currently find it confusing, and I think this optimization could be part of that.

Collaborator


Yeah, works for me. Can you run some benchmarks to ensure there is no perf regression?

((total_seq_lens, 4), torch.uint8),
)

return sparse_attn_indexer_fake(
Collaborator


Should these lines in sparse_attn_indexer_fake be removed? I wrote them to mimic the activation memory during the profile run.

_flattened_kv = torch.empty(
    [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)
_k_fp8 = _flattened_kv[..., :head_dim].view(torch.float8_e4m3fn).contiguous()
_k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous()

Collaborator Author


They're back 👍 (not sure when, haha; I guess I must have done it and just not pushed it).

Collaborator


I want to say that if you choose to reserve the buffers of shape (total_seq_lens, head_dim) and (total_seq_lens, 4) in the workspace, you don't need to run the following in sparse_attn_indexer_fake

_flattened_kv = torch.empty(
    [total_seq_lens, head_dim + 4], device=k.device, dtype=torch.uint8
)

to also reserve memory for these two tensors in activation memory.

@LucasWilkinson force-pushed the lwilkinson/upconvert-all-2 branch from 39ba79c to 4a49fc9 on November 9, 2025 21:41

mergify bot commented Nov 10, 2025

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @LucasWilkinson.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify bot added the needs-rebase label on Nov 10, 2025
Commit messages pushed to the branch, each signed off by Lucas Wilkinson <[email protected]>: cleanup; fix; fix; fix; clean-up revert to triton; cleanup; cleanup; cleanup; cleanup; keep; fix; cleanup; cleanup; cleanup; cleanup; cleanup; review comments; fixed; cleanup; minor optimization; remove get; clean up; fix.
@LucasWilkinson force-pushed the lwilkinson/upconvert-all-2 branch from a3f6647 to f665def on November 22, 2025 06:02
@mergify bot removed the needs-rebase label on Nov 22, 2025